{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d6e72215-501e-4bb0-ada9-7fb61849b524",
   "metadata": {},
   "source": [
    "# Jupyter执行LLaMA Demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "62202b1e-ca0e-493e-8809-6e8deeb7c6bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Select the GPUs before torch is imported: CUDA_VISIBLE_DEVICES only takes effect\n",
    "# if it is set before the CUDA runtime is initialised in this process.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,4,5'\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "from peft import PeftModel\n",
    "from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer\n",
    "\n",
    "# Sanity check: should report the number of GPUs listed above.\n",
    "print(torch.cuda.device_count())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2c6eed29-07b0-49f5-a314-ea73313f1118",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/jiw203/wjn/InstructGraph/examples/demo\n"
     ]
    }
   ],
   "source": [
    "!pwd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e4c24d7-f6ed-4c27-9bba-851564027054",
   "metadata": {},
   "source": [
    "### 加载LLaMA与PEFT模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a674b2ec-215e-44ee-879b-fbc3d7b063a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name_or_path = \"/home/jiw203/wjn/pre-trained-lm/Llama-2-7b-hf/\"\n",
    "peft_model = \"/home/jiw203/wjn/InstructGraph/output/instruction_tuning/fsdp_peft_1250k/llama2-peft\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "449eaa15-5c6a-47e7-92d5-ebe8c9e86f3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(model_name, quantization):\n",
    "    \"\"\"Load the base LLaMA causal language model used for text generation.\n",
    "\n",
    "    Args:\n",
    "        model_name: Checkpoint name or path passed to ``from_pretrained``.\n",
    "        quantization: Passed through as ``load_in_8bit``.\n",
    "\n",
    "    Returns:\n",
    "        The model returned by ``LlamaForCausalLM.from_pretrained``.\n",
    "    \"\"\"\n",
    "    load_kwargs = dict(\n",
    "        return_dict=True,\n",
    "        load_in_8bit=quantization,\n",
    "        device_map=\"auto\",\n",
    "        low_cpu_mem_usage=True,\n",
    "    )\n",
    "    return LlamaForCausalLM.from_pretrained(model_name, **load_kwargs)\n",
    "\n",
    "\n",
    "def load_peft_model(model, peft_model):\n",
    "    \"\"\"Wrap ``model`` with the PEFT adapter stored at ``peft_model``.\n",
    "\n",
    "    Args:\n",
    "        model: Base model to attach the adapter to.\n",
    "        peft_model: Adapter name or path passed to ``PeftModel.from_pretrained``.\n",
    "\n",
    "    Returns:\n",
    "        The object returned by ``PeftModel.from_pretrained``.\n",
    "    \"\"\"\n",
    "    return PeftModel.from_pretrained(model, peft_model)\n",
    "\n",
    "def load_llama_from_config(config_path):\n",
    "    \"\"\"Build a LLaMA model from its config only, as a target for FSDP checkpoints.\n",
    "\n",
    "    No pretrained weights are loaded here; only the configuration found at\n",
    "    ``config_path`` is read.\n",
    "\n",
    "    Args:\n",
    "        config_path: Location passed to ``LlamaConfig.from_pretrained``.\n",
    "\n",
    "    Returns:\n",
    "        A ``LlamaForCausalLM`` constructed from that configuration.\n",
    "    \"\"\"\n",
    "    return LlamaForCausalLM(config=LlamaConfig.from_pretrained(config_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8d51e532-8130-4eab-939b-2718b603c181",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fc819392959b4772bf7431a932266e8b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load the base model first, then attach the PEFT adapter to it.\n",
    "base_model = load_model(model_name_or_path, quantization=False)\n",
    "model = load_peft_model(base_model, peft_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "26cd638f-a644-4831-9716-7fbe8209f405",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PeftModelForCausalLM(\n",
       "  (base_model): LoraModel(\n",
       "    (model): LlamaForCausalLM(\n",
       "      (model): LlamaModel(\n",
       "        (embed_tokens): Embedding(32000, 4096)\n",
       "        (layers): ModuleList(\n",
       "          (0-31): 32 x LlamaDecoderLayer(\n",
       "            (self_attn): LlamaAttention(\n",
       "              (q_proj): Linear(\n",
       "                in_features=4096, out_features=4096, bias=False\n",
       "                (lora_dropout): ModuleDict(\n",
       "                  (default): Dropout(p=0.05, inplace=False)\n",
       "                )\n",
       "                (lora_A): ModuleDict(\n",
       "                  (default): Linear(in_features=4096, out_features=8, bias=False)\n",
       "                )\n",
       "                (lora_B): ModuleDict(\n",
       "                  (default): Linear(in_features=8, out_features=4096, bias=False)\n",
       "                )\n",
       "                (lora_embedding_A): ParameterDict()\n",
       "                (lora_embedding_B): ParameterDict()\n",
       "              )\n",
       "              (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "              (v_proj): Linear(\n",
       "                in_features=4096, out_features=4096, bias=False\n",
       "                (lora_dropout): ModuleDict(\n",
       "                  (default): Dropout(p=0.05, inplace=False)\n",
       "                )\n",
       "                (lora_A): ModuleDict(\n",
       "                  (default): Linear(in_features=4096, out_features=8, bias=False)\n",
       "                )\n",
       "                (lora_B): ModuleDict(\n",
       "                  (default): Linear(in_features=8, out_features=4096, bias=False)\n",
       "                )\n",
       "                (lora_embedding_A): ParameterDict()\n",
       "                (lora_embedding_B): ParameterDict()\n",
       "              )\n",
       "              (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "              (rotary_emb): LlamaRotaryEmbedding()\n",
       "            )\n",
       "            (mlp): LlamaMLP(\n",
       "              (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
       "              (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
       "              (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
       "              (act_fn): SiLUActivation()\n",
       "            )\n",
       "            (input_layernorm): LlamaRMSNorm()\n",
       "            (post_attention_layernorm): LlamaRMSNorm()\n",
       "          )\n",
       "        )\n",
       "        (norm): LlamaRMSNorm()\n",
       "      )\n",
       "      (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1a0396fe-d82a-4feb-93b8-acea79a6bd40",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path)\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efba0e8e-fe4a-44d6-96cd-66b6b089199f",
   "metadata": {},
   "source": [
    "### 自定义输入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "ad2f2e40-635b-4aad-9bf7-9958942dcbef",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': tensor([[    1,   887,   526,   263,  1781,  3983, 15299, 29889,   887,   817,\n",
      "           304,  2274,   278,  3414,  5023,   322,  5706,   263,  3983,  4086,\n",
      "           304,  1234,   278,  1139, 29889, 29871,    13,  9842, 29901,   313,\n",
      "         29875,   529,   976,   432, 29897,  2794,   393,  2943,   474,   322,\n",
      "          2943,   432,   526,  6631,   411,   385,   563,  1088,   287,  7636,\n",
      "         29889,   313, 29875,  1599,   432, 29897,  2794,   393,  2943,   474,\n",
      "           322,  2943,   432,   526,  6631,   411,   263, 10624,  7636, 29889,\n",
      "         29871,    13,  5398,  5023, 29901,  2183,   263,  2114,   950,  7134,\n",
      "          1139, 29892,  3113,  1348,  4331,   491,  4331, 29901, 29871, 29896,\n",
      "         29897,  1284,   278, 11261,  7855,   322,  5706,   263,  6590,  7134,\n",
      "          1014,  4262,   304,  4653,   278,  3186,  7134,  2472, 29892, 29871,\n",
      "         29906, 29897,   769,  5706,   263,  7291,  8252,   304,  8453,   445,\n",
      "         24481, 29892,   322, 29871, 29941, 29897,  7146,  1962,   278,  2186,\n",
      "          1234, 29889,  3940,   393, 29901, 29871, 29896, 29897,   278,  3983,\n",
      "           338,   263, 10624, 29899,  7915,   287,  3983,   322,   278,  1024,\n",
      "           338,   376, 17028,   950, 29899, 28385,  5485, 29899,  4262,  1642,\n",
      "         29871, 29906, 29897,   450,  5759,  3983,  4086,   881,   367,   263,\n",
      "           775, 29899,  4561,  3829, 29892,   322,   278, 18109, 11285,  3402,\n",
      "           508,   367, 13384,   408,   278,  1494, 29901,    13, 28956,    13,\n",
      "          9527, 29961,   978,   543, 17028,   950, 29899, 28385,  5485, 29899,\n",
      "          4262,  3108,   426,    13,  1678,  7855, 29918,  1761,   353,  6796,\n",
      "         12353,   613,  2023,  1385,    13,  1678, 21954, 29918,  1761,   353,\n",
      "           518,   703, 12353, 29908,  1599,   376, 12353,  1159, 29961, 23445,\n",
      "           543, 12353, 12436,  2023,  1385,    13, 29913,    13, 28956,    13,\n",
      "         29984, 29901,   825,   471,   278,  3971,   664,   393,   471,  3517,\n",
      "           297,   278,  3652,   443, 25047,  9946, 29973,    13, 29909, 29901]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n"
     ]
    }
   ],
   "source": [
    "# instruction = (\n",
    "#     'You are a good graph reasoner. Give you a graph language that describes a graph structure and node information. You need to understand the graph and the task definition, and answer the question.\\nNote: (i <-> j) means that node i and node j are connected with an undirected edge. (i -> j) means that node i and node j are connected with a directed edge. \\n```\\nGraph[name=\"factual-graph\"] {\\n    entity_list = [\\'Raid on Unadilla and Onaquaga\\', \\'New York\\'];\\n    triple_list = [(\"Raid on Unadilla and Onaquaga\" -> \"New York\")[relation=\"located in the administrative territorial entity\"], (\"Raid on Unadilla and Onaquaga\" -> \"New York\")[relation=\"location\"]];\\n}\\n```\\nTask definition: given an event title and a factual knowledge graph, generate an event narration.\\nQ: Please generate an event narration based on the knowledge graph and the title \"Raid on Unadilla and Onaquaga\".\\nA:'\n",
    "# )\n",
    "instruction = \"\"\"You are a good graph generator. You need to understand the task definition and generate a graph language to answer the question. \n",
    "Note: (i <-> j) means that node i and node j are connected with an undirected edge. (i -> j) means that node i and node j are connected with a directed edge. \n",
    "Task definition: given a factual knowledge question, please think step by step: 1) find the topic entity and generate a corresponding knowledge subgraph to express the world knowledge information, 2) then generate a thinking explanation to describe this reasoning, and 3) finally output the final answer. Note that: 1) the graph is a directed-weighted graph and the name is \"factual-knowledge-graph\". 2) The generated graph language should be a code-like structure, and the skeleton format can be expressed as the following:\n",
    "```\n",
    "Graph[name=\"factual-knowledge-graph\"] {\n",
    "    entity_list = [\"xxx\", ...];\n",
    "    triple_list = [(\"xxx\" -> \"xxx\")[relation=\"xxx\"], ...];\n",
    "}\n",
    "```\n",
    "Q: what was the written work that was previous in the series unnatural causes?\n",
    "A:\"\"\"\n",
    "batch = tokenizer(instruction, truncation=True, max_length=2048, return_tensors=\"pt\")\n",
    "print(batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ea55b93-6449-40e5-8cea-249bc74f2a9c",
   "metadata": {},
   "source": [
    "### 模型生成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "1c0b5500-c373-466c-a309-c541bc5f0be9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You are a good graph generator. You need to understand the task definition and generate a graph language to answer the question. \n",
      "Note: (i <-> j) means that node i and node j are connected with an undirected edge. (i -> j) means that node i and node j are connected with a directed edge. \n",
      "Task definition: given a factual knowledge question, please think step by step: 1) find the topic entity and generate a corresponding knowledge subgraph to express the world knowledge information, 2) then generate a thinking explanation to describe this reasoning, and 3) finally output the final answer. Note that: 1) the graph is a directed-weighted graph and the name is \"factual-knowledge-graph\". 2) The generated graph language should be a code-like structure, and the skeleton format can be expressed as the following:\n",
      "```\n",
      "Graph[name=\"factual-knowledge-graph\"] {\n",
      "    entity_list = [\"xxx\", ...];\n",
      "    triple_list = [(\"xxx\" -> \"xxx\")[relation=\"xxx\"], ...];\n",
      "}\n",
      "```\n",
      "Q: what was the written work that was previous in the series unnatural causes?\n",
      "A:I cannot answer the question directly. However, based on the world knowledge, the correct answer to the question is \"The Case of the Missing Will\"\n"
     ]
    }
   ],
   "source": [
    "# Move every tensor of the tokenised prompt to the GPU.\n",
    "batch = {name: tensor.to(\"cuda\") for name, tensor in batch.items()}\n",
    "\n",
    "# Generation settings, kept in one place; do_sample=False turns sampling off.\n",
    "generation_kwargs = dict(\n",
    "    max_new_tokens=500,\n",
    "    do_sample=False,\n",
    "    top_p=1.0,\n",
    "    temperature=1.0,\n",
    "    min_length=None,\n",
    "    use_cache=True,\n",
    "    top_k=50,\n",
    "    repetition_penalty=1.0,\n",
    "    length_penalty=1,\n",
    ")\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(**batch, **generation_kwargs)\n",
    "\n",
    "# outputs[0] is decoded in full, so the printed text starts with the prompt.\n",
    "output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "print(output_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c387e354-4d42-4066-9d25-99976fb99569",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b23064c1-2722-4325-8a48-60cde1c4bfaf",
   "metadata": {},
   "source": [
    "## 模型生成能力初检"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a577541-d905-463a-afa3-a0522612c5bb",
   "metadata": {},
   "source": [
    "graph structure modeling\n",
    "- Connectivity：无法生成\n",
    "- Cycle：无法生成\n",
    "- Topological Sort：可以生成\n",
    "- Shortest Path：无法生成\n",
    "- Maximum Flow：无法生成\n",
    "- Bipartite Graph Matching：无法生成\n",
    "- Hamilton Path：可以生成，但内容不符\n",
    "- Graph Degree：无法生成\n",
    "\n",
    "graph caption generation:\n",
    "- wikipedia：没问题\n",
    "- webnlg：没问题\n",
    "- agenda：没问题\n",
    "- genwiki：没问题\n",
    "- EventNarrative：没问题\n",
    "- XAlign：没问题\n",
    "\n",
    "graph question answering：\n",
    "- pathquestions：无法生成\n",
    "- wc2014：无法生成\n",
    "- grailqa：无法生成\n",
    "- webquestion：无法生成\n",
    "\n",
    "graph node cls：\n",
    "- cora：没问题\n",
    "- citeseer：无法生成\n",
    "- pubmed：没问题\n",
    "- ogbn-arxiv：\n",
    "- ogbn-products：无法生成\n",
    "\n",
    "graph link prediction\n",
    "- wikidata：无法生成\n",
    "- FB15k-237：可以生成，但内容不符\n",
    "- ConceptNet：无法生成\n",
    "\n",
    "Graph Relevance Inspection：没问题\n",
    "\n",
    "Graph Collaboration Filtering：\n",
    "- movielens：没问题\n",
    "- Amazon：无法生成\n",
    "- LastFM\n",
    "\n",
    "knowledge graph generation：（调整一下顺序，将passage放在q的前面）\n",
    "- wikidata：没问题，但存在重复生成问题\n",
    "- instructuie：没问题\n",
    "- instructkgc：没问题，但存在重复生成问题\n",
    "\n",
    "Structure Graph Generation：（调整一下顺序，将passage放在q的前面）\n",
    "没问题\n",
    "\n",
    "Graph thought modeling: 没问题，但是graph language 多了个[capacity=\"xxx\"]，且存在重复生成问题"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93cb4492-c34a-40d8-84bc-628025509731",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41565356-14fc-4325-8b9d-098c8690cd30",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b84bbd59-8521-44c8-b25c-a0ffa5037f13",
   "metadata": {},
   "source": [
    "## 错误记录"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "e32d8a25-b27c-4c41-8666-fbda868b4ac7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\nYou are a good graph reasoner. Give you a graph language that describe a graph structure and node infotmation. You need to understanding the graph and the task definition, and answer the question.\\nNote: (i <-> j) means that node i and node j are connected with an undirected edge. (i -> j) means that node i and node j are connected with an directed edge. \\n```\\nGraph[name=\"maximum-flow\"] {\\n    node_list = [0, 1, 2, 3, 4, 5];\\n    edge_list = [(0 -> 4)[capacity=3], (0 -> 1)[capacity=6], (0 -> 2)[capacity=3], (1 -> 2)[capacity=4], (1 -> 0)[capacity=4], (2 -> 4)[capacity=9], (2 -> 3)[capacity=7], (3 -> 4)[capacity=4], (3 -> 0)[capacity=5], (4 -> 1)[capacity=4], (4 -> 2)[capacity=4], (4 -> 0)[capacity=2], (5 -> 0)[capacity=3]];\\n}\\n```\\nTask definition: calculate the maximum flow between two nodes in the graph.\\nQ: What is the maximum flow from node 2 to node 4?\\nA:\\n\\n```\\nmax_flow_from_node_2_to_node_4 = 12\\n```\\n'"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"\"\"\n",
    "You are a good graph reasoner. Give you a graph language that describe a graph structure and node infotmation. You need to understanding the graph and the task definition, and answer the question.\n",
    "Note: (i <-> j) means that node i and node j are connected with an undirected edge. (i -> j) means that node i and node j are connected with an directed edge. \n",
    "```\n",
    "Graph[name=\"maximum-flow\"] {\n",
    "    node_list = [0, 1, 2, 3, 4, 5];\n",
    "    edge_list = [(0 -> 4)[capacity=3], (0 -> 1)[capacity=6], (0 -> 2)[capacity=3], (1 -> 2)[capacity=4], (1 -> 0)[capacity=4], (2 -> 4)[capacity=9], (2 -> 3)[capacity=7], (3 -> 4)[capacity=4], (3 -> 0)[capacity=5], (4 -> 1)[capacity=4], (4 -> 2)[capacity=4], (4 -> 0)[capacity=2], (5 -> 0)[capacity=3]];\n",
    "}\n",
    "```\n",
    "Task definition: calculate the maximum flow between two nodes in the graph.\n",
    "Q: What is the maximum flow from node 2 to node 4?\n",
    "A:\n",
    "\n",
    "```\n",
    "max_flow_from_node_2_to_node_4 = 12\n",
    "```\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "id": "f103e7e9-2bf3-4e4d-88d7-433caa19e6b9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You are a good graph generator. You need to understand the task definition and generate a graph language to answer the question. \n",
      "Note: (i <-> j) means that node i and node j are connected with an undirected edge. (i -> j) means that node i and node j are connected with a directed edge. \n",
      "Task definition: given a graph structure description, generate a corresponding graph language. Note that: 1) the graph is an undirected-weighted graph and the name is \"undirected-weighted-graph\". 2) The generated graph language should be a code-like structure, and the skeleton format can be expressed as the following:\n",
      "```\n",
      "Graph[name=\"undirected-weighted-graph\"] {\n",
      "    entity_list = [\"xxx\", ...];\n",
      "    triple_list = [(\"xxx\" <-> \"xxx\")[weight=\"xxx\"], ...];\n",
      "}\n",
      "```\n",
      "Q: Given you a graph structure description, please generate a corresponding graph.\n",
      "Description: \"In an undirected graph, the nodes are numbered from 0 to 4, and the edges are:\n",
      "an edge between node 0 and node 2 with weight 20,\n",
      "an edge between node 0 and node 3 with weight 1,\n",
      "an edge between node 1 and node 2 with weight 2,\n",
      "an edge between node 1 and node 4 with weight 2,\n",
      "an edge between node 1 and node 3 with weight 3,\n",
      "an edge between node 2 and node 4 with weight -3,\n",
      "an edge between node 2 and node 3 with weight 3.4,\n",
      "an edge between node 3 and node 4 with weight 1.\".\n",
      "A: \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[1,\n",
       " 887,\n",
       " 526,\n",
       " 263,\n",
       " 1781,\n",
       " 3983,\n",
       " 15299,\n",
       " 29889,\n",
       " 887,\n",
       " 817,\n",
       " 304,\n",
       " 2274,\n",
       " 278,\n",
       " 3414,\n",
       " 5023,\n",
       " 322,\n",
       " 5706,\n",
       " 263,\n",
       " 3983,\n",
       " 4086,\n",
       " 304,\n",
       " 1234,\n",
       " 278,\n",
       " 1139,\n",
       " 29889,\n",
       " 29871,\n",
       " 13,\n",
       " 9842,\n",
       " 29901,\n",
       " 313,\n",
       " 29875,\n",
       " 529,\n",
       " 976,\n",
       " 432,\n",
       " 29897,\n",
       " 2794,\n",
       " 393,\n",
       " 2943,\n",
       " 474,\n",
       " 322,\n",
       " 2943,\n",
       " 432,\n",
       " 526,\n",
       " 6631,\n",
       " 411,\n",
       " 385,\n",
       " 563,\n",
       " 1088,\n",
       " 287,\n",
       " 7636,\n",
       " 29889,\n",
       " 313,\n",
       " 29875,\n",
       " 1599,\n",
       " 432,\n",
       " 29897,\n",
       " 2794,\n",
       " 393,\n",
       " 2943,\n",
       " 474,\n",
       " 322,\n",
       " 2943,\n",
       " 432,\n",
       " 526,\n",
       " 6631,\n",
       " 411,\n",
       " 263,\n",
       " 10624,\n",
       " 7636,\n",
       " 29889,\n",
       " 29871,\n",
       " 13,\n",
       " 5398,\n",
       " 5023,\n",
       " 29901,\n",
       " 2183,\n",
       " 263,\n",
       " 3983,\n",
       " 3829,\n",
       " 6139,\n",
       " 29892,\n",
       " 5706,\n",
       " 263,\n",
       " 6590,\n",
       " 3983,\n",
       " 4086,\n",
       " 29889,\n",
       " 3940,\n",
       " 393,\n",
       " 29901,\n",
       " 29871,\n",
       " 29896,\n",
       " 29897,\n",
       " 278,\n",
       " 3983,\n",
       " 338,\n",
       " 385,\n",
       " 563,\n",
       " 1088,\n",
       " 287,\n",
       " 29899,\n",
       " 7915,\n",
       " 287,\n",
       " 3983,\n",
       " 322,\n",
       " 278,\n",
       " 1024,\n",
       " 338,\n",
       " 376,\n",
       " 870,\n",
       " 1088,\n",
       " 287,\n",
       " 29899,\n",
       " 7915,\n",
       " 287,\n",
       " 29899,\n",
       " 4262,\n",
       " 1642,\n",
       " 29871,\n",
       " 29906,\n",
       " 29897,\n",
       " 450,\n",
       " 5759,\n",
       " 3983,\n",
       " 4086,\n",
       " 881,\n",
       " 367,\n",
       " 263,\n",
       " 775,\n",
       " 29899,\n",
       " 4561,\n",
       " 3829,\n",
       " 29892,\n",
       " 322,\n",
       " 278,\n",
       " 18109,\n",
       " 11285,\n",
       " 3402,\n",
       " 508,\n",
       " 367,\n",
       " 13384,\n",
       " 408,\n",
       " 278,\n",
       " 1494,\n",
       " 29901,\n",
       " 13,\n",
       " 28956,\n",
       " 13,\n",
       " 9527,\n",
       " 29961,\n",
       " 978,\n",
       " 543,\n",
       " 870,\n",
       " 1088,\n",
       " 287,\n",
       " 29899,\n",
       " 7915,\n",
       " 287,\n",
       " 29899,\n",
       " 4262,\n",
       " 3108,\n",
       " 426,\n",
       " 13,\n",
       " 1678,\n",
       " 7855,\n",
       " 29918,\n",
       " 1761,\n",
       " 353,\n",
       " 6796,\n",
       " 12353,\n",
       " 613,\n",
       " 2023,\n",
       " 1385,\n",
       " 13,\n",
       " 1678,\n",
       " 21954,\n",
       " 29918,\n",
       " 1761,\n",
       " 353,\n",
       " 518,\n",
       " 703,\n",
       " 12353,\n",
       " 29908,\n",
       " 529,\n",
       " 976,\n",
       " 376,\n",
       " 12353,\n",
       " 1159,\n",
       " 29961,\n",
       " 7915,\n",
       " 543,\n",
       " 12353,\n",
       " 12436,\n",
       " 2023,\n",
       " 1385,\n",
       " 13,\n",
       " 29913,\n",
       " 13,\n",
       " 28956,\n",
       " 13,\n",
       " 29984,\n",
       " 29901,\n",
       " 11221,\n",
       " 366,\n",
       " 263,\n",
       " 3983,\n",
       " 3829,\n",
       " 6139,\n",
       " 29892,\n",
       " 3113,\n",
       " 5706,\n",
       " 263,\n",
       " 6590,\n",
       " 3983,\n",
       " 29889,\n",
       " 13,\n",
       " 9868,\n",
       " 29901,\n",
       " 376,\n",
       " 797,\n",
       " 385,\n",
       " 563,\n",
       " 1088,\n",
       " 287,\n",
       " 3983,\n",
       " 29892,\n",
       " 278,\n",
       " 7573,\n",
       " 526,\n",
       " 1353,\n",
       " 287,\n",
       " 515,\n",
       " 29871,\n",
       " 29900,\n",
       " 304,\n",
       " 29871,\n",
       " 29946,\n",
       " 29892,\n",
       " 322,\n",
       " 278,\n",
       " 12770,\n",
       " 526,\n",
       " 29901,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29900,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29906,\n",
       " 411,\n",
       " 7688,\n",
       " 29871,\n",
       " 29906,\n",
       " 29900,\n",
       " 29892,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29900,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29941,\n",
       " 411,\n",
       " 7688,\n",
       " 29871,\n",
       " 29896,\n",
       " 29892,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29896,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29906,\n",
       " 411,\n",
       " 7688,\n",
       " 29871,\n",
       " 29906,\n",
       " 29892,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29896,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29946,\n",
       " 411,\n",
       " 7688,\n",
       " 29871,\n",
       " 29906,\n",
       " 29892,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29896,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29941,\n",
       " 411,\n",
       " 7688,\n",
       " 29871,\n",
       " 29941,\n",
       " 29892,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29906,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29946,\n",
       " 411,\n",
       " 7688,\n",
       " 448,\n",
       " 29941,\n",
       " 29892,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29906,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29941,\n",
       " 411,\n",
       " 7688,\n",
       " 29871,\n",
       " 29941,\n",
       " 29889,\n",
       " 29946,\n",
       " 29892,\n",
       " 13,\n",
       " 273,\n",
       " 7636,\n",
       " 1546,\n",
       " 2943,\n",
       " 29871,\n",
       " 29941,\n",
       " 322,\n",
       " 2943,\n",
       " 29871,\n",
       " 29946,\n",
       " 411,\n",
       " 7688,\n",
       " 29871,\n",
       " 29896,\n",
       " 1213,\n",
       " 29889,\n",
       " 13,\n",
       " 29909,\n",
       " 29901,\n",
       " 29871]"
      ]
     },
     "execution_count": 155,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(instruction)\n",
    "a = list(tokenizer.encode(instruction))\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "522c76ea-bbf0-475b-ade7-ceef8674233d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['<s>', '▁A', ':', 'The', '▁thought', '▁graph', '▁that', '▁express', 'es']"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.convert_ids_to_tokens([1, 319, 29901, 1576, 2714, 3983, 393, 4653, 267])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "a2d1e50c-70a2-4eb1-813f-1c010752e544",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.convert_tokens_to_ids(\"</s>\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "id": "689718b6-02b4-43bf-9431-d2adad2995b4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 160,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e9d605e-8029-44c7-bb66-0a245fd23d43",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "1a75da00-cf2f-4346-b988-899e51fbc68d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 319, 29901, 1576, 2714, 3983, 393, 4653, 267, 278, 24481, 10757, 338, 7521, 13, 9527, 29961, 978, 543, 23147, 292, 29899, 386, 1774, 29899, 4262, 3108, 426, 13, 1678, 2943, 29918, 1761, 353, 6796, 25719, 613, 376, 8835, 4982, 613, 376, 1287, 613, 376, 9057, 613, 376, 8835, 10370, 13, 1678, 7636, 29918, 1761, 353, 518, 703, 25719, 29908, 1599, 376, 8835, 4982, 29908, 29961, 5030, 5946, 543, 29874, 326, 3108, 511, 4852, 1287, 29908, 1599, 376, 9057, 29908, 29961, 5030, 5946, 543, 6689, 310, 2125, 2058, 3108, 511, 4852, 9057, 29908, 1599, 376, 8835, 29908, 29961, 5030, 5946, 543, 28111, 20068, 1385, 13, 29913, 13, 16159, 1412, 29871, 13, 1576, 1234, 881, 367, 278, 7306, 310, 2305, 472, 664, 29889, 4587, 278, 2038, 19995, 29892, 278, 7306, 310, 2305, 472, 664, 338, 304, 4866, 278, 4982, 29889, 29871, 13, 6295, 278, 1234, 338, 319, 29889, 29871, 2]\n"
     ]
    }
   ],
   "source": [
    "import copy\n",
    "IGNORE_INDEX = -100\n",
    "instruction = \"\"\"A:\"\"\"\n",
    "\n",
    "answer = \"\"\"The thought graph that expresses the reasoning evidence is ```\\nGraph[name=\"reasoning-thought-graph\"] {\\n    node_list = [\"people\", \"complete job\", \"work\", \"job\", \"complete\"];\\n    edge_list = [(\"people\" -> \"complete job\"[capacity=\"aim\"]), (\"work\" -> \"job\"[capacity=\"place of take place\"]), (\"job\" -> \"complete\"[capacity=\"goal\"])];\\n}\\n```. \\nThe answer should be the goal of people at work. Of the above choices, the goal of people at work is to complete the job. \\nSo the answer is A. \"\"\"\n",
    "\n",
    "\n",
    "# Tokenise the prompt on its own and the prompt followed by the answer.\n",
    "prompt_ids = tokenizer.encode(instruction)\n",
    "example_ids = tokenizer.encode(instruction + answer)\n",
    "\n",
    "# Keep at most 2047 tokens so the EOS appended below fits within 2048.\n",
    "example_ids = example_ids[: 2048 - 1]\n",
    "example_ids.append(tokenizer.eos_token_id)\n",
    "print(example_ids)\n",
    "\n",
    "prompt = torch.tensor(prompt_ids, dtype=torch.int64)\n",
    "example = torch.tensor(example_ids, dtype=torch.int64)\n",
    "\n",
    "# Flag the prompt positions with -1, then replace every negative entry with\n",
    "# IGNORE_INDEX in the labels and with 0 in the inputs.\n",
    "labels = copy.deepcopy(example)\n",
    "labels[: len(prompt)] = -1\n",
    "example_mask = example >= 0\n",
    "label_mask = labels >= 0\n",
    "example[~example_mask] = 0\n",
    "labels[~label_mask] = IGNORE_INDEX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "fdce3f03-595d-4b04-b7d5-0fbdac2cc37d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ -100,  -100,  -100,  1576,  2714,  3983,   393,  4653,   267,   278,\n",
      "        24481, 10757,   338,  7521,    13,  9527, 29961,   978,   543, 23147,\n",
      "          292, 29899,   386,  1774, 29899,  4262,  3108,   426,    13,  1678,\n",
      "         2943, 29918,  1761,   353,  6796, 25719,   613,   376,  8835,  4982,\n",
      "          613,   376,  1287,   613,   376,  9057,   613,   376,  8835, 10370,\n",
      "           13,  1678,  7636, 29918,  1761,   353,   518,   703, 25719, 29908,\n",
      "         1599,   376,  8835,  4982, 29908, 29961,  5030,  5946,   543, 29874,\n",
      "          326,  3108,   511,  4852,  1287, 29908,  1599,   376,  9057, 29908,\n",
      "        29961,  5030,  5946,   543,  6689,   310,  2125,  2058,  3108,   511,\n",
      "         4852,  9057, 29908,  1599,   376,  8835, 29908, 29961,  5030,  5946,\n",
      "          543, 28111, 20068,  1385,    13, 29913,    13, 16159,  1412, 29871,\n",
      "           13,  1576,  1234,   881,   367,   278,  7306,   310,  2305,   472,\n",
      "          664, 29889,  4587,   278,  2038, 19995, 29892,   278,  7306,   310,\n",
      "         2305,   472,   664,   338,   304,  4866,   278,  4982, 29889, 29871,\n",
      "           13,  6295,   278,  1234,   338,   319, 29889, 29871,     2])\n"
     ]
    }
   ],
   "source": [
    "print(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a33d04b0-6e40-4600-8e65-fc0caaa04986",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[1, 2, 3],\n",
       " [1, -2, 3],\n",
       " [4, 4, 5],\n",
       " [1, 2, 3],\n",
       " [1, -2, 3],\n",
       " [4, 4, 5],\n",
       " [1, 2, 3],\n",
       " [1, -2, 3],\n",
       " [4, 4, 5]]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[[1, 2, 3], [1, -2, 3], [4, 4, 5]]*3"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
